Use newer version of mma_atom and copy_atom in 00_bmg_gemm #540
…d_copy_*, and move tensor/copy initialization to host-side params in to_underlying_arguments
Approving with the minor changes suggested above.
Edit -- there is a bug in the TiledCopy handling that needs fixing, described below.
Hi @anamikac-intel, with this PR, I'm encountering the same errors locally as the CI. Thanks!
Theoretical bf16 peak perf for BMG is 116 TF/s, so the new performance is too high. Either there's a problem in the kernel (not doing the full computation) or something's wrong with the performance computation.
Fixes a compilation failure found in #540 when >2D tensors are passed to one of the `make_block_2d_copy_*` functions.
  using ArchTag = typename DispatchPolicy::ArchTag;

- static_assert(platform::is_same<ElementA, ElementB>::value, "MainloopIntelXeXMX16 requires that A and B have same type.");
+ static_assert(platform::is_same<ElementA, ElementB>::value, "MainloopXeL1Staged requires that A and B have same type.");
In the existing MMA collective code, we use the variable names `ATOM_M`, `ATOM_N`, and `ATOM_K` incorrectly, because they don't correspond to the underlying MMA atom, but to our tiling scheme instead.
static constexpr int ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr int ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr int ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
Workgroup tiles are divided spatially into subgroup fragments/tiles.
For example, the variable `ATOM_M` is actually the number of subgroup tiles along M that make up a workgroup tile; i.e., `ATOM_M` means `WG_M`/`SG_M` and does not represent the atom's M dimension.
Can we rename these variables in this PR? It's not necessary for correctness, but it would make the code easier to understand.
Thanks!
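To make the naming point concrete, here is a minimal sketch in plain C++ (no CuTe dependency). The tile sizes and the names `SG_TILES_M`, `SG_TILES_N`, and `SUBGROUPS_PER_WG` are assumptions chosen for illustration only; they are not taken from this PR or from the collective code.

```cpp
// Illustrative tile sizes only (assumed, not taken from the PR).
constexpr int WG_M = 256;  // workgroup tile extent along M
constexpr int WG_N = 256;  // workgroup tile extent along N
constexpr int SG_M = 32;   // subgroup tile extent along M
constexpr int SG_N = 64;   // subgroup tile extent along N

// The workgroup tile must split evenly into subgroup tiles.
static_assert(WG_M % SG_M == 0, "WG_M must be a multiple of SG_M");
static_assert(WG_N % SG_N == 0, "WG_N must be a multiple of SG_N");

// Hypothetical names for what the collective currently calls ATOM_M / ATOM_N:
// the number of subgroup tiles per workgroup tile in each mode.
constexpr int SG_TILES_M = WG_M / SG_M;
constexpr int SG_TILES_N = WG_N / SG_N;

// Total subgroups cooperating on one workgroup tile (one K partition assumed).
constexpr int SUBGROUPS_PER_WG = SG_TILES_M * SG_TILES_N;
```

With these assumed sizes, the quantity described above as `WG_M`/`SG_M` is 8, which says nothing about the M extent of a single MMA atom; a name along the lines of `SG_TILES_M` would reflect that.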
I agree with you, but we should fix it in another PR: a new feature in our latest release strongly depends on this PR, and we expect it to be merged ASAP.
…ck 2D Copy Utilities
Modify 00_bmg_gemm to include new mma and copy atoms (#477).
00_bmg_gemm combines two parts: mma and epilogue. To add the new atom changes, we need to update both parts, since they currently use the old atoms. As a start, we will:
Old Atom:
Problem Size: 5120x4096x4096x1
Cutlass GEMM Performance: [96.448]TFlop/s (1.7813)ms
New Atom:
Problem Size: 5120x4096x4096x1
Cutlass GEMM Performance: [97.259]TFlop/s (1.7664)ms
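As a sanity check on reported figures like these against the 116 TF/s theoretical peak mentioned in the review, the following sketch derives TFLOP/s from the problem size and runtime. It assumes the conventional count of 2 flops (one multiply, one add) per inner-product term; the helper name `gemm_tflops` is hypothetical and is not claimed to be how the example itself computes its output.

```cpp
#include <cstdint>

// Hypothetical helper (not part of the example): converts a GEMM problem
// size M x N x K x batch and a runtime in milliseconds into TFLOP/s,
// assuming 2 flops per multiply-accumulate.
constexpr double gemm_tflops(std::int64_t m, std::int64_t n, std::int64_t k,
                             std::int64_t batch, double runtime_ms) {
  const double flops = 2.0 * static_cast<double>(m) * n * k * batch;
  const double seconds = runtime_ms * 1e-3;
  return flops / seconds / 1e12;
}

// Theoretical bf16 peak for BMG, as quoted in the review comment.
constexpr double kPeakTflopsBf16 = 116.0;

// True when a reported figure is physically plausible, i.e. not above peak.
constexpr bool below_peak(double tflops) { return tflops <= kPeakTflopsBf16; }
```

For the 5120x4096x4096x1 problem, `gemm_tflops(5120, 4096, 4096, 1, 1.7813)` gives about 96.4 and `gemm_tflops(5120, 4096, 4096, 1, 1.7664)` gives about 97.3, which is consistent with the two results listed above and below the quoted peak. A figure for which `below_peak` is false would point to either an incomplete computation in the kernel or an error in the timing or flop count.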